package com.xinxin.datastructure.tree;

/**
 * @author 史鑫鑫
 * @date Created in 2019/6/12 21:29
 */
@SuppressWarnings("unchecked")
public class LinkedListSegmentTree<E> {
    /**
     * 节点内部类
     */
    private class Node {
        E e;
        int l;
        int r;
        Node left;
        Node right;

        Node(E e, int l, int r) {
            this.e = e;
            this.l = l;
            this.r = r;
            this.left = null;
            this.right = null;
        }

        Node(int l, int r) {
            this(null, l, r);
        }

        @Override
        public String toString() {
            return "Node{" +
                    "e=" + e +
                    ", l=" + l +
                    ", r=" + r +
                    '}';
        }
    }

    /**
     * 原数组
     */
    private E[] data;

    /**
     * 线段树
     */
    private Node root;

    /**
     * 合并规则
     */
    private Merger<E> merger;

    /**
     * 利用一个数组个合并规则创建一棵线段树
     *
     * @param arr    原始数组
     * @param merger 合并规则
     */
    public LinkedListSegmentTree(E[] arr, Merger<E> merger) {
        this.merger = merger;
        data = (E[]) new Object[arr.length];
        System.arraycopy(arr, 0, data, 0, arr.length);
        root = buildSegmentTree(0, data.length - 1);
    }

    /**
     * 构建l到r之间的线段树，并返回线段树的根节点
     *
     * @param l 左边界
     * @param r 右边界
     * @return 线段树根节点
     */
    private Node buildSegmentTree(int l, int r) {
        if (l == r) {
            return new Node(data[l], l, r);
        }
        int mid = l + (r - l) / 2;
        Node node = new Node(l, r);
        node.left = buildSegmentTree(l, mid);
        node.right = buildSegmentTree(mid + 1, r);
        node.e = merger.merge(node.left.e, node.right.e);
        return node;
    }

    /**
     * 返回区间内的元素个数
     *
     * @return 区间内的元素个数
     */
    public int getSize() {
        return data.length;
    }

    /**
     * 获取index索引位置的元素
     *
     * @param index 索引
     * @return 该索引位置的元素
     */
    public E get(int index) {
        if (index < 0 || index >= data.length) {
            throw new IllegalArgumentException("Index is illegal.");
        }
        return data[index];
    }

    /**
     * 返回区间[queryL, queryR]的值
     *
     * @param queryL 左边界
     * @param queryR 右边界
     * @return 合并值
     */
    public E query(int queryL, int queryR) {
        // 参数不合法
        if (queryL < 0 || queryL >= data.length ||
                queryR < 0 || queryR >= data.length ||
                queryL > queryR) {
            throw new IllegalArgumentException("Index is illegal.");
        }
        return query(root, 0, data.length - 1, queryL, queryR);
    }

    /**
     * 在以node为根的线段树中[l...r]的范围里，搜索区间[queryL, queryR]的值
     *
     * @param node   根节点索引
     * @param l      线段树左边界
     * @param r      线段树右边界
     * @param queryL 搜索区间左边界
     * @param queryR 搜索区间右边界
     * @return 搜索值
     */
    private E query(Node node, int l, int r, int queryL, int queryR) {
        if (l == queryL && r == queryR) {
            return node.e;
        }
        int mid = l + (r - l) / 2;
        Node leftNode = node.left;
        Node rightNode = node.right;
        if (queryL >= mid + 1) {
            return query(rightNode, mid + 1, r, queryL, queryR);
        } else if (queryR <= mid) {
            return query(leftNode, l, mid, queryL, queryR);
        } else {
            E leftResult = query(leftNode, l, mid, queryL, mid);
            E rightResult = query(rightNode, mid + 1, r, mid + 1, queryR);
            return merger.merge(leftResult, rightResult);
        }
    }

    /**
     * 将index位置的值，更新为e
     *
     * @param index 索引
     * @param e     更新值
     */
    public void set(int index, E e) {
        if (index < 0 || index >= data.length) {
            throw new IllegalArgumentException("Index is illegal");
        }
        data[index] = e;
        set(root, 0, data.length - 1, index, e);
    }

    /**
     * 在以node为根的线段树中更新index的值为e
     *
     * @param node  根节点索引
     * @param l     左边界
     * @param r     右边界
     * @param index 更新索引值
     * @param e     更新值
     */
    private void set(Node node, int l, int r, int index, E e) {
        if (l == r) {
            node.e = e;
            return;
        }
        int mid = l + (r - l) / 2;
        Node leftNode = node.left;
        Node rightNode = node.right;
        if (index >= mid + 1) {
            set(rightNode, mid + 1, r, index, e);
        } else {
            set(leftNode, l, mid, index, e);
        }
        node.e = merger.merge(leftNode.e, rightNode.e);
    }

    private String toString(Node node) {
        if (node == null) {
            return "null ";
        }
        return node + " " + toString(node.left) + toString(node.right);
    }

    @Override
    public String toString() {
        return toString(root);
    }
}
